import shutil
def count_data(list_path="./data/custom/test.txt",
               names_path="data/custom/classes.names"):
    """Print how many label boxes of each class the listed images contain.

    Every image path listed in *list_path* is mapped to its label file by
    replacing "images" with "labels" and "jpg" with "txt" (every occurrence
    in the path). Each non-blank label line starts with an integer index into
    the class-name file *names_path*; the rest of the line is not used here.

    Args:
        list_path: text file with one image path per line.
        names_path: class-name file, one name per line; the line number is
            the class index.

    Prints the raw class-name list, echoes the whole label file once for
    every 'pl15' or 'pl90' box it contains, and finally prints one
    ``<class index> <name> <count>`` line per class, in order of first
    appearance.

    Raises:
        FileNotFoundError: a listed label file does not exist.
        ValueError: a label line does not start with an integer.
        IndexError: a class index is beyond the end of the class-name file.
    """
    with open(list_path, 'r') as f:
        file = f.readlines()
    with open(names_path, 'r') as f:
        index_list = f.readlines()
    print(index_list)
    # Strip newlines once: a last class name without a trailing "\n" must
    # still be found when reporting (looking up k + "\n" raised ValueError).
    names = [name.replace("\n", "") for name in index_list]
    label_dict = {}
    for path in file:
        image_path = path.replace("\n", "")
        if not image_path.strip():
            continue  # blank line in the list file, e.g. a trailing one
        label_path = image_path.replace("images", "labels").replace("jpg", "txt")
        with open(label_path, 'r') as f:
            labelfile = f.readlines()
        for label in labelfile:
            fields = label.split()
            if not fields:
                continue  # blank lines carry no box
            label_name = names[int(fields[0])]
            if label_name in ('pl15', 'pl90'):
                print(labelfile)
            label_dict[label_name] = label_dict.get(label_name, 0) + 1
    for k, count in label_dict.items():
        print(names.index(k), k, count)
def my_data():
    """Rewrite label files with 'pl15'/'pl90' boxes removed and classes renumbered.

    For every image path in ./data/custom/test.txt, the source label file is
    found by replacing "images" with "labels_0" and "jpg" with "txt" (every
    occurrence in the path). Each non-blank label line starts with an integer
    index into data/custom/classes.names:

    * lines whose class is 'pl15' or 'pl90' are dropped;
    * every other line gets its index replaced by the position of the same
      class name in data/custom/re_classes.names. Everything after the first
      space is kept verbatim.

    If at least one line survives, the remapped lines are written to the
    label path with "labels_0" replaced by "labels", and the image path line
    is appended to ./data/custom/re_test.txt. Images whose boxes were all
    dropped get no rewritten label file and are not listed. When a box was
    dropped, the surviving lines are printed for inspection.

    Raises:
        ValueError: a kept class name is missing from re_classes.names, or a
            label line does not start with an integer.
        IndexError: a class index is beyond the end of classes.names.
    """
    with open("./data/custom/test.txt", 'r') as f:
        file = f.readlines()
    with open("data/custom/classes.names", 'r') as f:
        index_list = f.readlines()
    with open("data/custom/re_classes.names", 'r') as f:
        re_index_list = f.readlines()
    print(index_list)
    names = [name.replace("\n", "") for name in index_list]
    # Compare class names without their newline, so that a last line lacking
    # "\n" in either names file still matches. setdefault keeps the first
    # occurrence, like list.index did.
    re_index_of = {}
    for re_index, re_name in enumerate(re_index_list):
        re_index_of.setdefault(re_name.replace("\n", ""), re_index)
    # `with` guarantees the output list is flushed and closed, even on error.
    with open("./data/custom/re_test.txt", 'w') as re_all_data_file:
        for path in file:
            image_path = path.replace("\n", "")
            if not image_path.strip():
                continue  # blank line in the list file, e.g. a trailing one
            label_path = image_path.replace("images", "labels_0").replace("jpg", "txt")
            label_re_write_data = []
            flag = False  # set when at least one box of this image is dropped
            with open(label_path, 'r') as f:
                labelfile = f.readlines()
            for label in labelfile:
                if not label.strip():
                    continue  # blank lines carry no box
                head, _, rest = label.partition(" ")
                label_name = names[int(head)]
                if label_name in ('pl15', 'pl90'):
                    flag = True
                    continue
                if label_name not in re_index_of:
                    raise ValueError(
                        "class " + repr(label_name) + " used in " + label_path
                        + " is missing from data/custom/re_classes.names")
                # `rest` still carries the line's trailing newline, if any.
                label_re_write_data.append(str(re_index_of[label_name]) + " " + rest)
            if flag:
                print(label_re_write_data)
                print("-" * 100)
            if label_re_write_data:
                re_all_data_file.write(path)
                with open(label_path.replace("labels_0", "labels"), 'w') as f:
                    f.writelines(label_re_write_data)
def recover_data(train_ratio=0.7):
    """Split the images listed in re_all_data.txt into train and test lists.

    Label files are found by replacing "images" with "re_labels" and "jpg"
    with "txt" in each listed image path. Each non-blank label line starts
    with an integer index into data/custom/classes.names.

    Pass 1 (report): prints the raw class-name list and echoes a label file
    once for every 'pl15' or 'pl90' box it contains. It then prints
    ``<class index> <name> <count>`` per class, in order of first appearance.

    Pass 2 (split): walks the images in list order, keeping a running
    per-class box count that includes the current image. Each image is routed
    by the FIRST class in its label file. If that class's running count is at
    most ``int(total * train_ratio)``, the image path line goes to
    ./data/custom/re_train.txt. Otherwise it goes to ./data/custom/test.txt.
    Either way, the label file is copied unchanged to the path with
    "re_labels" replaced by "re0_labels". Images without boxes are not
    listed in either file.

    Args:
        train_ratio: fraction of each class's total box count used as the
            train quota (default 0.7, i.e. 7/10).

    Raises:
        ValueError: a label line does not start with an integer.
        IndexError: a class index is beyond the end of classes.names.
    """
    with open("./data/custom/re_all_data.txt", 'r') as f:
        file = f.readlines()
    with open("data/custom/classes.names", 'r') as f:
        index_list = f.readlines()
    print(index_list)
    names = [name.replace("\n", "") for name in index_list]

    # Pass 1: read every label file once and total the boxes per class.
    samples = []  # (path line, label path, raw label lines, class names in order)
    label_dict = {}
    for path in file:
        image_path = path.replace("\n", "")
        if not image_path.strip():
            continue  # blank line in the list file, e.g. a trailing one
        label_path = image_path.replace("images", "re_labels").replace("jpg", "txt")
        with open(label_path, 'r') as f:
            labelfile = f.readlines()
        re_list = []
        for label in labelfile:
            fields = label.split()
            if not fields:
                continue  # blank lines carry no box
            label_name = names[int(fields[0])]
            if label_name in ('pl15', 'pl90'):
                print(labelfile)
            re_list.append(label_name)
            label_dict[label_name] = label_dict.get(label_name, 0) + 1
        samples.append((path, label_path, labelfile, re_list))
    for k, count in label_dict.items():
        print(names.index(k), k, count)

    # Pass 2: route each image by the running count of its first class.
    # NOTE(review): the test split overwrites data/custom/test.txt (the list
    # count_data()/my_data() read) while the train split goes to
    # re_train.txt; re_test.txt may have been intended — confirm.
    running = {}
    with open("./data/custom/re_train.txt", "w") as re_train, \
            open("./data/custom/test.txt", "w") as re_test:
        for path, label_path, labelfile, re_list in samples:
            for label_name in re_list:
                running[label_name] = running.get(label_name, 0) + 1
            if not re_list:
                continue  # no boxes: the image is not assigned to a split
            first = re_list[0]
            if running[first] <= int(label_dict[first] * train_ratio):
                re_train.write(path)
            else:
                re_test.write(path)
            with open(label_path.replace("re_labels", "re0_labels"), 'w') as f:
                f.writelines(labelfile)
if __name__ == "__main__":
    # Script entry point: only the per-class count report runs by default.
    # my_data() and recover_data() write files and must be called explicitly.
    count_data()